
import os
import configparser
import json

from core.formalization.rl.rl_manager import RLManager
from core.formalization.action_space import ActionType
from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from llm.auxiliary import Auxiliary
from core.formalization.symbol_manager import SymbolManager
from llm.llm_const import (
    MODEL_ID_SILICONFLOW_DEEPSEEK_V3,
    MODEL_ID_SILICONFLOW_BGE_M3,
)

# Directory containing this script. Resolved to an absolute path so every
# path derived from it (config files, assets, cache, output) does not depend
# on the current working directory; os.path.dirname(__file__) alone can be
# '' or relative depending on how the module was launched.
project_root = os.path.dirname(os.path.abspath(__file__))
# Module logger; the 'dev' argument is passed straight to the project Logger
# (presumably selects a logging profile/level — confirm in utils.logger).
logger = Logger(__name__, 'dev')

def read_config(local_config_path, default_config_path, encoding='utf-8'):
    """Load an INI configuration, preferring a local override file.

    The local file is tried first; if it cannot be read (missing, a
    directory, unreadable), the default file is used instead. The path that
    was actually parsed is printed.

    Args:
        local_config_path: Path to an optional override config file.
        default_config_path: Path to the fallback config file.
        encoding: Text encoding used to read the file.

    Returns:
        A ``configparser.ConfigParser`` populated from the first readable file.

    Raises:
        FileNotFoundError: If neither file could be read.
        configparser.Error: If the chosen file is malformed.
    """
    config = configparser.ConfigParser()
    for path in (local_config_path, default_config_path):
        # ConfigParser.read() silently skips files it cannot open and returns
        # the list of files it successfully parsed. Use that result rather
        # than os.path.exists(), so we never report a file as loaded when it
        # was not (and fall back to the default in that case).
        if config.read(path, encoding=encoding):
            print(f'Loaded configuration from {path}')
            return config
    raise FileNotFoundError(
        f'No readable configuration found at {local_config_path!r} '
        f'or {default_config_path!r}'
    )

# Global configuration for this run: an override file next to this script
# takes precedence, otherwise the shared default config is used.
test_config = read_config(
    local_config_path=os.path.join(project_root, 'local_config.ini'),
    default_config_path=os.path.join(project_root, 'default_config.ini'),
)

def _init_llm(model_id, api_key, model_save_path=None, model_cache_path=None):
    """Construct an ``LLMWrapper`` for *model_id* using the module logger.

    This only builds the wrapper; any further initialisation is left to the
    caller.

    Args:
        model_id: Passed through as the wrapper config's ``model_id``.
        api_key: Passed through as the wrapper config's ``api_key``.
        model_save_path: Optional path, passed through as ``model_save_path``.
        model_cache_path: Optional path, passed through as ``model_cache_path``.

    Returns:
        The newly constructed ``LLMWrapper``.
    """
    return LLMWrapper(
        config=dict(
            model_id=model_id,
            model_save_path=model_save_path,
            model_cache_path=model_cache_path,
            api_key=api_key,
        ),
        logger=logger,
    )

# --- Settings read from the loaded configuration -------------------------
huggingface_cache_dir = test_config.get('BASE', 'huggingface-cache')
huggingface_save_dir = test_config.get('BASE', 'huggingface-save')
siliconflow_api_key = test_config.get('API_KEY', 'siliconflow')
llm_type = test_config.get('EXP', 'llm')
embedding = test_config.get('EXP', 'embedding')

model_save_path = os.path.join(huggingface_save_dir, 'model')
model_cache_path = os.path.join(huggingface_cache_dir, 'model')

# Supported short names for [EXP] llm / embedding, mapped to
# (model id, API key). Previously an unknown name fell through the if-chain
# and crashed later with a NameError (for the embedding, only after the LLM
# had already been initialised); now it fails fast with a clear message.
_LLM_CHOICES = {
    "ds": (MODEL_ID_SILICONFLOW_DEEPSEEK_V3, siliconflow_api_key),
}
_EMBEDDING_CHOICES = {
    "bge_m3": (MODEL_ID_SILICONFLOW_BGE_M3, siliconflow_api_key),
}

if llm_type not in _LLM_CHOICES:
    raise ValueError(
        f"Unsupported [EXP] llm = {llm_type!r}; "
        f"expected one of {sorted(_LLM_CHOICES)}"
    )
if embedding not in _EMBEDDING_CHOICES:
    raise ValueError(
        f"Unsupported [EXP] embedding = {embedding!r}; "
        f"expected one of {sorted(_EMBEDDING_CHOICES)}"
    )
model_id, model_api_key = _LLM_CHOICES[llm_type]
embedding_id, embedding_api_key = _EMBEDDING_CHOICES[embedding]

# --- Model and managers ---------------------------------------------------
llm = _init_llm(
    model_id,
    model_api_key,
    model_save_path,
    model_cache_path,
)
llm.init()
assets_dir = os.path.join(project_root, 'assets')
# Cache and output are separated per LLM type so runs with different models
# do not overwrite each other.
cache_dir = os.path.join(project_root, 'cache', llm_type)
output_dir = os.path.join(project_root, 'output', llm_type)
config = {
    'cache_dir': cache_dir,
    'output_dir': output_dir,
    'api_generate_model_id': model_id,
    'api_embedding_model_id': embedding_id,
    'generate_api_key': model_api_key,
    'embedding_api_key': embedding_api_key,
    # NOTE(review): looks like the name of a previously saved run/checkpoint
    # (timestamp-style); confirm how RLManager/SymbolManager consume it.
    'model_name': "20250828_103838_final",
}

auxiliary = Auxiliary(logger, config)
symbol_manager = SymbolManager(logger, auxiliary, config)
rl_manager = RLManager(logger, llm, auxiliary, symbol_manager, config)

# --- Training data --------------------------------------------------------
# One JSON object per line with 'query', 'target' and a nested 'category'
# dict holding 'primary_category' and 'secondary_category'.
filepath = os.path.join(assets_dir, "category_adv_bench_train.jsonl")
queries = []
with open(filepath, 'r', encoding='utf-8') as f:
    for line in f:
        # Tolerate blank lines (json.loads would raise on them).
        if not line.strip():
            continue
        info = json.loads(line)
        query = info['query']
        target = info['target']
        category = info['category']

        primary_category = category['primary_category']
        secondary_category = category['secondary_category']
        queries.append({
            "query": query,
            "target": target,
            "category": f"{primary_category}: {secondary_category}"
        })

# NOTE(review): 500 is presumably the number of training iterations/episodes;
# confirm against RLManager.train's signature.
rl_manager.train(queries, 500)